{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters\n",
    "learning_rate = 0.001\n",
    "training_epochs = 6\n",
    "batch_size = 600\n",
    "\n",
    "# Import MNIST data\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "\n",
    "train_dataset = (\n",
    "    tf.data.Dataset.from_tensor_slices((tf.reshape(x_train, [-1, 784]), y_train))\n",
    "    .batch(batch_size)\n",
    "    .shuffle(1000)\n",
    ")\n",
    "\n",
    "train_dataset = (\n",
    "    train_dataset.map(lambda x, y:\n",
    "                      (tf.divide(tf.cast(x, tf.float32), 255.0),\n",
    "                       tf.reshape(tf.one_hot(y, 10), (-1, 10))))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set model weights\n",
    "W = tf.Variable(tf.zeros([784, 10]))\n",
    "b = tf.Variable(tf.zeros([10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> loss 195.05 acc 0.10\n",
      "=> loss 191.49 acc 0.57\n",
      "=> loss 188.84 acc 0.53\n",
      "=> loss 186.42 acc 0.57\n",
      "=> loss 182.66 acc 0.57\n",
      "=> loss 178.83 acc 0.67\n",
      "=> loss 178.17 acc 0.61\n",
      "=> loss 174.41 acc 0.62\n",
      "=> loss 171.08 acc 0.67\n",
      "=> loss 169.72 acc 0.71\n",
      "=> loss 165.83 acc 0.66\n",
      "=> loss 163.73 acc 0.71\n",
      "=> loss 166.10 acc 0.63\n",
      "=> loss 163.47 acc 0.68\n",
      "=> loss 155.01 acc 0.72\n",
      "=> loss 151.60 acc 0.70\n",
      "=> loss 153.95 acc 0.72\n",
      "=> loss 149.91 acc 0.69\n",
      "=> loss 143.68 acc 0.75\n",
      "=> loss 144.36 acc 0.74\n",
      "=> loss 146.78 acc 0.70\n",
      "=> loss 138.88 acc 0.75\n",
      "=> loss 143.94 acc 0.74\n",
      "=> loss 134.62 acc 0.76\n",
      "=> loss 131.31 acc 0.78\n",
      "=> loss 130.84 acc 0.77\n",
      "=> loss 128.80 acc 0.80\n",
      "=> loss 126.33 acc 0.78\n",
      "=> loss 122.29 acc 0.77\n",
      "=> loss 126.45 acc 0.79\n",
      "=> loss 120.56 acc 0.81\n",
      "=> loss 120.24 acc 0.77\n",
      "=> loss 123.38 acc 0.74\n",
      "=> loss 120.65 acc 0.76\n",
      "=> loss 116.62 acc 0.82\n",
      "=> loss 111.72 acc 0.81\n",
      "=> loss 103.49 acc 0.85\n",
      "=> loss 104.84 acc 0.79\n",
      "=> loss 109.73 acc 0.77\n",
      "=> loss 105.46 acc 0.80\n",
      "=> loss 107.44 acc 0.80\n",
      "=> loss 109.96 acc 0.78\n",
      "=> loss 94.47 acc 0.84\n",
      "=> loss 102.30 acc 0.80\n",
      "=> loss 104.75 acc 0.81\n",
      "=> loss 105.57 acc 0.78\n",
      "=> loss 100.41 acc 0.80\n",
      "=> loss 105.97 acc 0.76\n",
      "=> loss 93.61 acc 0.83\n",
      "=> loss 98.71 acc 0.80\n",
      "=> loss 88.41 acc 0.81\n",
      "=> loss 88.66 acc 0.84\n",
      "=> loss 88.69 acc 0.82\n",
      "=> loss 88.91 acc 0.84\n",
      "=> loss 89.90 acc 0.82\n",
      "=> loss 81.47 acc 0.86\n",
      "=> loss 87.65 acc 0.82\n",
      "=> loss 96.22 acc 0.75\n",
      "=> loss 88.73 acc 0.79\n",
      "=> loss 93.59 acc 0.76\n",
      "=> loss 68.55 acc 0.92\n",
      "=> loss 92.06 acc 0.77\n",
      "=> loss 91.29 acc 0.78\n",
      "=> loss 83.23 acc 0.84\n",
      "=> loss 88.66 acc 0.79\n",
      "=> loss 78.69 acc 0.85\n",
      "=> loss 85.99 acc 0.83\n",
      "=> loss 81.69 acc 0.84\n",
      "=> loss 83.03 acc 0.83\n",
      "=> loss 69.32 acc 0.87\n",
      "=> loss 72.08 acc 0.86\n",
      "=> loss 74.06 acc 0.84\n",
      "=> loss 78.92 acc 0.83\n",
      "=> loss 82.12 acc 0.79\n",
      "=> loss 81.89 acc 0.82\n",
      "=> loss 79.07 acc 0.83\n",
      "=> loss 70.92 acc 0.86\n",
      "=> loss 78.10 acc 0.81\n",
      "=> loss 72.35 acc 0.85\n",
      "=> loss 72.12 acc 0.86\n",
      "=> loss 72.27 acc 0.85\n",
      "=> loss 63.14 acc 0.87\n",
      "=> loss 83.00 acc 0.79\n",
      "=> loss 75.93 acc 0.80\n",
      "=> loss 69.23 acc 0.84\n",
      "=> loss 67.33 acc 0.86\n",
      "=> loss 74.30 acc 0.85\n",
      "=> loss 71.87 acc 0.85\n",
      "=> loss 73.17 acc 0.86\n",
      "=> loss 84.00 acc 0.80\n",
      "=> loss 78.82 acc 0.79\n",
      "=> loss 62.96 acc 0.86\n",
      "=> loss 76.37 acc 0.81\n",
      "=> loss 76.01 acc 0.80\n",
      "=> loss 62.87 acc 0.85\n",
      "=> loss 75.88 acc 0.82\n",
      "=> loss 65.08 acc 0.87\n",
      "=> loss 57.16 acc 0.90\n",
      "=> loss 64.22 acc 0.86\n",
      "=> loss 69.01 acc 0.85\n",
      "=> loss 74.83 acc 0.79\n",
      "=> loss 58.67 acc 0.87\n",
      "=> loss 71.85 acc 0.82\n",
      "=> loss 71.90 acc 0.85\n",
      "=> loss 56.35 acc 0.89\n",
      "=> loss 61.49 acc 0.86\n",
      "=> loss 67.84 acc 0.83\n",
      "=> loss 65.34 acc 0.84\n",
      "=> loss 65.05 acc 0.85\n",
      "=> loss 59.87 acc 0.89\n",
      "=> loss 64.02 acc 0.83\n",
      "=> loss 59.84 acc 0.88\n",
      "=> loss 73.06 acc 0.82\n",
      "=> loss 72.97 acc 0.82\n",
      "=> loss 61.39 acc 0.85\n",
      "=> loss 60.32 acc 0.86\n",
      "=> loss 66.19 acc 0.83\n",
      "=> loss 63.12 acc 0.85\n",
      "=> loss 55.07 acc 0.89\n",
      "=> loss 59.72 acc 0.86\n",
      "=> loss 65.85 acc 0.81\n",
      "=> loss 57.84 acc 0.86\n",
      "=> loss 68.17 acc 0.84\n",
      "=> loss 61.81 acc 0.83\n",
      "=> loss 55.14 acc 0.88\n",
      "=> loss 62.46 acc 0.85\n",
      "=> loss 55.20 acc 0.89\n",
      "=> loss 54.89 acc 0.88\n",
      "=> loss 52.97 acc 0.88\n",
      "=> loss 58.93 acc 0.86\n",
      "=> loss 50.85 acc 0.90\n",
      "=> loss 57.71 acc 0.86\n",
      "=> loss 65.83 acc 0.81\n",
      "=> loss 60.83 acc 0.85\n",
      "=> loss 54.85 acc 0.89\n",
      "=> loss 56.16 acc 0.84\n",
      "=> loss 45.72 acc 0.92\n",
      "=> loss 50.32 acc 0.88\n",
      "=> loss 57.01 acc 0.85\n",
      "=> loss 52.93 acc 0.88\n",
      "=> loss 56.05 acc 0.87\n",
      "=> loss 61.58 acc 0.85\n",
      "=> loss 47.79 acc 0.89\n",
      "=> loss 54.13 acc 0.87\n",
      "=> loss 57.36 acc 0.87\n",
      "=> loss 62.06 acc 0.86\n",
      "=> loss 55.85 acc 0.86\n",
      "=> loss 62.17 acc 0.84\n",
      "=> loss 50.05 acc 0.89\n",
      "=> loss 56.72 acc 0.87\n",
      "=> loss 46.38 acc 0.89\n",
      "=> loss 49.10 acc 0.88\n",
      "=> loss 49.17 acc 0.90\n",
      "=> loss 51.10 acc 0.89\n",
      "=> loss 54.57 acc 0.86\n",
      "=> loss 44.41 acc 0.90\n",
      "=> loss 52.87 acc 0.87\n",
      "=> loss 64.05 acc 0.82\n",
      "=> loss 53.97 acc 0.87\n",
      "=> loss 63.43 acc 0.82\n",
      "=> loss 33.79 acc 0.95\n",
      "=> loss 60.66 acc 0.82\n",
      "=> loss 59.98 acc 0.84\n",
      "=> loss 51.33 acc 0.89\n",
      "=> loss 56.61 acc 0.86\n",
      "=> loss 49.01 acc 0.89\n",
      "=> loss 57.31 acc 0.87\n",
      "=> loss 50.45 acc 0.89\n",
      "=> loss 53.45 acc 0.88\n",
      "=> loss 42.61 acc 0.90\n",
      "=> loss 45.05 acc 0.91\n",
      "=> loss 46.82 acc 0.89\n",
      "=> loss 51.74 acc 0.88\n",
      "=> loss 59.85 acc 0.83\n",
      "=> loss 55.95 acc 0.86\n",
      "=> loss 53.84 acc 0.86\n",
      "=> loss 45.14 acc 0.89\n",
      "=> loss 53.96 acc 0.86\n",
      "=> loss 48.23 acc 0.88\n",
      "=> loss 49.62 acc 0.88\n",
      "=> loss 49.22 acc 0.89\n",
      "=> loss 41.33 acc 0.90\n",
      "=> loss 61.45 acc 0.83\n",
      "=> loss 54.19 acc 0.86\n",
      "=> loss 48.09 acc 0.86\n",
      "=> loss 46.77 acc 0.88\n",
      "=> loss 49.98 acc 0.89\n",
      "=> loss 51.37 acc 0.87\n",
      "=> loss 50.57 acc 0.89\n",
      "=> loss 61.70 acc 0.83\n",
      "=> loss 57.45 acc 0.84\n",
      "=> loss 44.15 acc 0.88\n",
      "=> loss 56.44 acc 0.85\n",
      "=> loss 57.71 acc 0.83\n",
      "=> loss 43.93 acc 0.89\n",
      "=> loss 57.54 acc 0.84\n",
      "=> loss 46.34 acc 0.89\n",
      "=> loss 36.78 acc 0.93\n",
      "=> loss 46.03 acc 0.89\n",
      "=> loss 50.24 acc 0.88\n",
      "=> loss 56.08 acc 0.85\n",
      "=> loss 41.50 acc 0.90\n",
      "=> loss 57.56 acc 0.83\n",
      "=> loss 55.26 acc 0.87\n",
      "=> loss 39.14 acc 0.91\n",
      "=> loss 47.44 acc 0.88\n",
      "=> loss 50.92 acc 0.87\n",
      "=> loss 49.87 acc 0.86\n",
      "=> loss 50.32 acc 0.88\n",
      "=> loss 42.60 acc 0.92\n",
      "=> loss 48.94 acc 0.86\n",
      "=> loss 45.29 acc 0.90\n",
      "=> loss 56.31 acc 0.85\n",
      "=> loss 54.40 acc 0.86\n",
      "=> loss 46.92 acc 0.87\n",
      "=> loss 46.15 acc 0.88\n",
      "=> loss 51.53 acc 0.87\n",
      "=> loss 48.06 acc 0.89\n",
      "=> loss 41.07 acc 0.92\n",
      "=> loss 45.47 acc 0.88\n",
      "=> loss 51.86 acc 0.85\n",
      "=> loss 44.39 acc 0.89\n",
      "=> loss 52.87 acc 0.87\n",
      "=> loss 49.75 acc 0.86\n",
      "=> loss 41.95 acc 0.89\n",
      "=> loss 49.78 acc 0.87\n",
      "=> loss 41.61 acc 0.90\n",
      "=> loss 42.42 acc 0.91\n",
      "=> loss 40.12 acc 0.91\n",
      "=> loss 44.33 acc 0.88\n",
      "=> loss 36.65 acc 0.94\n",
      "=> loss 45.28 acc 0.88\n",
      "=> loss 53.70 acc 0.84\n",
      "=> loss 47.87 acc 0.87\n",
      "=> loss 41.31 acc 0.92\n",
      "=> loss 45.42 acc 0.87\n",
      "=> loss 35.46 acc 0.93\n",
      "=> loss 39.41 acc 0.90\n",
      "=> loss 45.40 acc 0.87\n",
      "=> loss 41.30 acc 0.90\n",
      "=> loss 43.86 acc 0.88\n",
      "=> loss 51.16 acc 0.87\n",
      "=> loss 38.22 acc 0.90\n",
      "=> loss 43.49 acc 0.89\n",
      "=> loss 46.26 acc 0.89\n",
      "=> loss 51.22 acc 0.87\n",
      "=> loss 44.76 acc 0.89\n",
      "=> loss 49.46 acc 0.87\n",
      "=> loss 39.99 acc 0.90\n",
      "=> loss 45.95 acc 0.88\n",
      "=> loss 36.29 acc 0.90\n",
      "=> loss 39.61 acc 0.90\n",
      "=> loss 38.75 acc 0.91\n",
      "=> loss 41.87 acc 0.90\n",
      "=> loss 45.73 acc 0.88\n",
      "=> loss 34.90 acc 0.92\n",
      "=> loss 43.28 acc 0.88\n",
      "=> loss 54.42 acc 0.84\n",
      "=> loss 44.44 acc 0.89\n",
      "=> loss 55.18 acc 0.84\n",
      "=> loss 25.99 acc 0.95\n",
      "=> loss 51.32 acc 0.85\n",
      "=> loss 51.02 acc 0.86\n",
      "=> loss 42.16 acc 0.89\n",
      "=> loss 46.00 acc 0.88\n",
      "=> loss 40.67 acc 0.90\n",
      "=> loss 48.81 acc 0.88\n",
      "=> loss 41.02 acc 0.90\n",
      "=> loss 44.08 acc 0.90\n",
      "=> loss 34.98 acc 0.92\n",
      "=> loss 37.42 acc 0.92\n",
      "=> loss 38.23 acc 0.91\n",
      "=> loss 43.01 acc 0.90\n",
      "=> loss 52.98 acc 0.85\n",
      "=> loss 46.87 acc 0.88\n",
      "=> loss 46.28 acc 0.87\n",
      "=> loss 37.06 acc 0.90\n",
      "=> loss 45.78 acc 0.87\n",
      "=> loss 40.37 acc 0.90\n",
      "=> loss 42.17 acc 0.89\n",
      "=> loss 41.78 acc 0.89\n",
      "=> loss 34.06 acc 0.92\n",
      "=> loss 53.61 acc 0.84\n",
      "=> loss 46.44 acc 0.87\n",
      "=> loss 41.33 acc 0.87\n",
      "=> loss 40.24 acc 0.89\n",
      "=> loss 41.10 acc 0.91\n",
      "=> loss 44.26 acc 0.89\n",
      "=> loss 42.56 acc 0.90\n",
      "=> loss 53.60 acc 0.86\n",
      "=> loss 48.93 acc 0.87\n",
      "=> loss 37.77 acc 0.89\n",
      "=> loss 48.58 acc 0.87\n",
      "=> loss 51.20 acc 0.86\n",
      "=> loss 37.17 acc 0.90\n",
      "=> loss 50.63 acc 0.85\n",
      "=> loss 39.87 acc 0.89\n",
      "=> loss 29.53 acc 0.94\n",
      "=> loss 39.04 acc 0.90\n",
      "=> loss 42.96 acc 0.89\n",
      "=> loss 48.26 acc 0.86\n",
      "=> loss 34.94 acc 0.91\n",
      "=> loss 52.43 acc 0.85\n",
      "=> loss 48.92 acc 0.88\n",
      "=> loss 32.90 acc 0.93\n",
      "=> loss 42.20 acc 0.90\n",
      "=> loss 44.09 acc 0.88\n",
      "=> loss 43.93 acc 0.88\n",
      "=> loss 44.81 acc 0.88\n",
      "=> loss 35.81 acc 0.93\n",
      "=> loss 42.75 acc 0.88\n",
      "=> loss 39.61 acc 0.90\n",
      "=> loss 49.15 acc 0.88\n",
      "=> loss 46.15 acc 0.88\n",
      "=> loss 41.05 acc 0.88\n",
      "=> loss 40.41 acc 0.90\n",
      "=> loss 45.48 acc 0.88\n",
      "=> loss 41.57 acc 0.91\n",
      "=> loss 35.13 acc 0.92\n",
      "=> loss 39.37 acc 0.89\n",
      "=> loss 46.10 acc 0.86\n",
      "=> loss 38.94 acc 0.90\n",
      "=> loss 45.95 acc 0.89\n",
      "=> loss 45.07 acc 0.87\n",
      "=> loss 36.38 acc 0.91\n",
      "=> loss 44.04 acc 0.88\n",
      "=> loss 35.87 acc 0.92\n",
      "=> loss 37.15 acc 0.92\n",
      "=> loss 34.42 acc 0.93\n",
      "=> loss 37.69 acc 0.90\n",
      "=> loss 30.54 acc 0.94\n",
      "=> loss 39.70 acc 0.89\n",
      "=> loss 48.03 acc 0.86\n",
      "=> loss 41.97 acc 0.88\n",
      "=> loss 35.13 acc 0.92\n",
      "=> loss 40.56 acc 0.89\n",
      "=> loss 31.31 acc 0.93\n",
      "=> loss 34.44 acc 0.91\n",
      "=> loss 39.95 acc 0.88\n",
      "=> loss 35.92 acc 0.91\n",
      "=> loss 38.08 acc 0.89\n",
      "=> loss 46.42 acc 0.88\n",
      "=> loss 33.69 acc 0.91\n",
      "=> loss 38.69 acc 0.90\n",
      "=> loss 41.29 acc 0.89\n",
      "=> loss 46.05 acc 0.88\n",
      "=> loss 39.58 acc 0.90\n",
      "=> loss 42.92 acc 0.88\n",
      "=> loss 35.44 acc 0.91\n",
      "=> loss 40.85 acc 0.89\n",
      "=> loss 31.56 acc 0.92\n",
      "=> loss 35.08 acc 0.91\n",
      "=> loss 33.80 acc 0.92\n",
      "=> loss 37.64 acc 0.91\n",
      "=> loss 41.55 acc 0.88\n",
      "=> loss 30.42 acc 0.93\n",
      "=> loss 38.43 acc 0.89\n",
      "=> loss 49.53 acc 0.86\n",
      "=> loss 39.75 acc 0.90\n",
      "=> loss 51.12 acc 0.86\n",
      "=> loss 22.56 acc 0.95\n",
      "=> loss 46.48 acc 0.87\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> loss 46.59 acc 0.87\n",
      "=> loss 37.51 acc 0.90\n",
      "=> loss 40.45 acc 0.89\n",
      "=> loss 36.56 acc 0.91\n",
      "=> loss 44.61 acc 0.89\n",
      "=> loss 36.38 acc 0.92\n",
      "=> loss 39.42 acc 0.90\n",
      "=> loss 31.19 acc 0.92\n",
      "=> loss 33.75 acc 0.92\n",
      "=> loss 33.85 acc 0.92\n",
      "=> loss 38.62 acc 0.91\n",
      "=> loss 49.52 acc 0.86\n",
      "=> loss 41.80 acc 0.89\n",
      "=> loss 42.63 acc 0.88\n",
      "=> loss 32.95 acc 0.91\n",
      "=> loss 41.53 acc 0.88\n",
      "=> loss 36.28 acc 0.91\n",
      "=> loss 38.24 acc 0.90\n",
      "=> loss 38.05 acc 0.90\n",
      "=> loss 30.17 acc 0.92\n",
      "=> loss 49.28 acc 0.85\n",
      "=> loss 42.25 acc 0.88\n",
      "=> loss 37.91 acc 0.88\n",
      "=> loss 36.97 acc 0.90\n",
      "=> loss 36.47 acc 0.92\n",
      "=> loss 40.56 acc 0.89\n",
      "=> loss 38.35 acc 0.90\n",
      "=> loss 49.30 acc 0.86\n",
      "=> loss 44.14 acc 0.88\n",
      "=> loss 34.42 acc 0.90\n",
      "=> loss 44.10 acc 0.88\n",
      "=> loss 47.79 acc 0.87\n",
      "=> loss 33.56 acc 0.91\n",
      "=> loss 46.88 acc 0.87\n",
      "=> loss 36.54 acc 0.90\n",
      "=> loss 25.69 acc 0.94\n",
      "=> loss 35.17 acc 0.92\n",
      "=> loss 38.99 acc 0.90\n",
      "=> loss 43.83 acc 0.88\n",
      "=> loss 31.36 acc 0.91\n",
      "=> loss 49.78 acc 0.86\n",
      "=> loss 45.50 acc 0.88\n",
      "=> loss 29.64 acc 0.93\n",
      "=> loss 39.29 acc 0.91\n",
      "=> loss 40.22 acc 0.88\n",
      "=> loss 40.71 acc 0.88\n",
      "=> loss 41.89 acc 0.89\n",
      "=> loss 32.12 acc 0.93\n",
      "=> loss 39.28 acc 0.88\n",
      "=> loss 36.44 acc 0.91\n",
      "=> loss 45.04 acc 0.88\n",
      "=> loss 41.43 acc 0.89\n",
      "=> loss 37.76 acc 0.89\n",
      "=> loss 37.24 acc 0.91\n",
      "=> loss 42.04 acc 0.89\n",
      "=> loss 37.80 acc 0.91\n",
      "=> loss 31.73 acc 0.93\n",
      "=> loss 35.88 acc 0.90\n",
      "=> loss 42.87 acc 0.87\n",
      "=> loss 35.98 acc 0.92\n",
      "=> loss 41.85 acc 0.90\n",
      "=> loss 42.61 acc 0.88\n",
      "=> loss 33.21 acc 0.92\n",
      "=> loss 40.62 acc 0.89\n",
      "=> loss 32.65 acc 0.93\n",
      "=> loss 34.17 acc 0.92\n",
      "=> loss 31.11 acc 0.93\n",
      "=> loss 33.79 acc 0.90\n",
      "=> loss 27.10 acc 0.95\n",
      "=> loss 36.41 acc 0.90\n",
      "=> loss 44.59 acc 0.88\n",
      "=> loss 38.49 acc 0.89\n",
      "=> loss 31.51 acc 0.92\n",
      "=> loss 37.70 acc 0.89\n",
      "=> loss 29.14 acc 0.94\n",
      "=> loss 31.50 acc 0.91\n",
      "=> loss 36.67 acc 0.88\n",
      "=> loss 32.71 acc 0.92\n",
      "=> loss 34.59 acc 0.90\n",
      "=> loss 43.66 acc 0.88\n",
      "=> loss 30.90 acc 0.92\n",
      "=> loss 35.90 acc 0.90\n",
      "=> loss 38.46 acc 0.90\n",
      "=> loss 42.97 acc 0.88\n",
      "=> loss 36.53 acc 0.91\n",
      "=> loss 38.85 acc 0.90\n",
      "=> loss 32.79 acc 0.92\n",
      "=> loss 37.85 acc 0.89\n",
      "=> loss 28.76 acc 0.92\n",
      "=> loss 32.32 acc 0.92\n",
      "=> loss 30.88 acc 0.93\n",
      "=> loss 35.19 acc 0.91\n",
      "=> loss 39.08 acc 0.89\n",
      "=> loss 27.78 acc 0.94\n",
      "=> loss 35.38 acc 0.90\n",
      "=> loss 46.54 acc 0.87\n",
      "=> loss 36.84 acc 0.91\n",
      "=> loss 48.61 acc 0.87\n",
      "=> loss 20.63 acc 0.96\n",
      "=> loss 43.42 acc 0.88\n",
      "=> loss 43.86 acc 0.87\n",
      "=> loss 34.61 acc 0.91\n",
      "=> loss 36.99 acc 0.90\n",
      "=> loss 34.06 acc 0.91\n",
      "=> loss 42.05 acc 0.89\n",
      "=> loss 33.59 acc 0.92\n",
      "=> loss 36.63 acc 0.90\n",
      "=> loss 28.85 acc 0.93\n",
      "=> loss 31.54 acc 0.92\n",
      "=> loss 31.16 acc 0.92\n",
      "=> loss 35.96 acc 0.92\n",
      "=> loss 47.38 acc 0.87\n",
      "=> loss 38.47 acc 0.90\n",
      "=> loss 40.44 acc 0.88\n",
      "=> loss 30.42 acc 0.92\n",
      "=> loss 38.90 acc 0.88\n",
      "=> loss 33.71 acc 0.91\n",
      "=> loss 35.76 acc 0.90\n",
      "=> loss 35.79 acc 0.90\n",
      "=> loss 27.67 acc 0.93\n",
      "=> loss 46.47 acc 0.86\n",
      "=> loss 39.55 acc 0.88\n",
      "=> loss 35.77 acc 0.89\n",
      "=> loss 34.98 acc 0.90\n",
      "=> loss 33.62 acc 0.92\n",
      "=> loss 38.26 acc 0.90\n",
      "=> loss 35.72 acc 0.91\n",
      "=> loss 46.59 acc 0.88\n",
      "=> loss 41.00 acc 0.89\n",
      "=> loss 32.27 acc 0.90\n",
      "=> loss 41.11 acc 0.88\n",
      "=> loss 45.66 acc 0.87\n",
      "=> loss 31.25 acc 0.92\n",
      "=> loss 44.48 acc 0.87\n",
      "=> loss 34.49 acc 0.91\n",
      "=> loss 23.23 acc 0.94\n",
      "=> loss 32.70 acc 0.92\n",
      "=> loss 36.47 acc 0.91\n",
      "=> loss 40.97 acc 0.88\n",
      "=> loss 29.08 acc 0.92\n",
      "=> loss 48.13 acc 0.87\n",
      "=> loss 43.31 acc 0.89\n",
      "=> loss 27.61 acc 0.93\n",
      "=> loss 37.39 acc 0.91\n",
      "=> loss 37.67 acc 0.89\n",
      "=> loss 38.65 acc 0.88\n",
      "=> loss 40.05 acc 0.89\n",
      "=> loss 29.80 acc 0.94\n",
      "=> loss 37.00 acc 0.89\n",
      "=> loss 34.35 acc 0.91\n",
      "=> loss 42.31 acc 0.88\n",
      "=> loss 38.35 acc 0.90\n",
      "=> loss 35.60 acc 0.90\n",
      "=> loss 35.21 acc 0.91\n",
      "=> loss 39.74 acc 0.89\n",
      "=> loss 35.28 acc 0.92\n",
      "=> loss 29.49 acc 0.93\n",
      "=> loss 33.58 acc 0.90\n",
      "=> loss 40.73 acc 0.87\n",
      "=> loss 34.12 acc 0.92\n",
      "=> loss 39.09 acc 0.90\n",
      "=> loss 41.10 acc 0.89\n",
      "=> loss 31.12 acc 0.93\n",
      "=> loss 38.32 acc 0.90\n",
      "=> loss 30.58 acc 0.93\n",
      "=> loss 32.23 acc 0.92\n",
      "=> loss 28.91 acc 0.93\n",
      "=> loss 31.15 acc 0.91\n",
      "=> loss 24.86 acc 0.95\n",
      "=> loss 34.19 acc 0.90\n",
      "=> loss 42.22 acc 0.88\n",
      "=> loss 36.15 acc 0.89\n",
      "=> loss 29.11 acc 0.92\n",
      "=> loss 35.78 acc 0.90\n",
      "=> loss 27.83 acc 0.94\n",
      "=> loss 29.52 acc 0.92\n",
      "=> loss 34.45 acc 0.89\n",
      "=> loss 30.54 acc 0.92\n",
      "=> loss 32.22 acc 0.90\n",
      "=> loss 41.83 acc 0.88\n",
      "=> loss 28.98 acc 0.92\n",
      "=> loss 34.06 acc 0.91\n",
      "=> loss 36.63 acc 0.90\n",
      "=> loss 40.88 acc 0.89\n",
      "=> loss 34.51 acc 0.91\n",
      "=> loss 36.07 acc 0.90\n",
      "=> loss 31.02 acc 0.92\n",
      "=> loss 35.88 acc 0.90\n",
      "=> loss 26.88 acc 0.93\n",
      "=> loss 30.43 acc 0.92\n",
      "=> loss 28.93 acc 0.93\n",
      "=> loss 33.57 acc 0.91\n",
      "=> loss 37.46 acc 0.90\n",
      "=> loss 26.03 acc 0.94\n",
      "=> loss 33.24 acc 0.91\n",
      "=> loss 44.53 acc 0.87\n",
      "=> loss 34.82 acc 0.92\n",
      "=> loss 46.88 acc 0.88\n",
      "=> loss 19.37 acc 0.96\n",
      "=> loss 41.27 acc 0.89\n",
      "=> loss 41.96 acc 0.87\n",
      "=> loss 32.61 acc 0.91\n",
      "=> loss 34.59 acc 0.91\n",
      "=> loss 32.34 acc 0.92\n",
      "=> loss 40.30 acc 0.90\n",
      "=> loss 31.70 acc 0.92\n",
      "=> loss 34.78 acc 0.91\n",
      "=> loss 27.23 acc 0.93\n",
      "=> loss 30.04 acc 0.92\n",
      "=> loss 29.31 acc 0.92\n",
      "=> loss 34.18 acc 0.92\n",
      "=> loss 45.88 acc 0.87\n",
      "=> loss 36.09 acc 0.90\n",
      "=> loss 38.93 acc 0.89\n",
      "=> loss 28.71 acc 0.92\n",
      "=> loss 37.11 acc 0.89\n",
      "=> loss 31.92 acc 0.92\n",
      "=> loss 34.03 acc 0.91\n",
      "=> loss 34.27 acc 0.90\n",
      "=> loss 25.90 acc 0.93\n",
      "=> loss 44.47 acc 0.88\n",
      "=> loss 37.66 acc 0.89\n",
      "=> loss 34.27 acc 0.90\n",
      "=> loss 33.62 acc 0.90\n",
      "=> loss 31.66 acc 0.92\n",
      "=> loss 36.68 acc 0.90\n",
      "=> loss 33.89 acc 0.91\n",
      "=> loss 44.71 acc 0.88\n",
      "=> loss 38.75 acc 0.89\n",
      "=> loss 30.73 acc 0.90\n",
      "=> loss 38.93 acc 0.89\n",
      "=> loss 44.21 acc 0.88\n",
      "=> loss 29.60 acc 0.92\n",
      "=> loss 42.80 acc 0.87\n",
      "=> loss 33.08 acc 0.92\n",
      "=> loss 21.51 acc 0.94\n",
      "=> loss 30.99 acc 0.93\n",
      "=> loss 34.72 acc 0.91\n"
     ]
    }
   ],
   "source": [
    "# Construct model\n",
    "model = lambda x: tf.nn.softmax(tf.matmul(x, W) + b) # Softmax\n",
    "# Minimize error using cross entropy\n",
    "compute_loss = lambda true, pred: tf.reduce_mean(tf.reduce_sum(tf.losses.binary_crossentropy(true, pred), axis=-1))\n",
    "# caculate accuracy\n",
    "compute_accuracy = lambda true, pred: tf.reduce_mean(tf.keras.metrics.categorical_accuracy(true, pred))\n",
    "# Gradient Descent\n",
    "optimizer = tf.optimizers.Adam(learning_rate)\n",
    "\n",
    "for epoch in range(training_epochs):\n",
    "    for i, (x_, y_) in enumerate(train_dataset):\n",
    "        with tf.GradientTape() as tape:\n",
    "            pred = model(x_)\n",
    "            loss = compute_loss(y_, pred)\n",
    "        acc = compute_accuracy(y_, pred)\n",
    "        grads = tape.gradient(loss, [W, b])\n",
    "        optimizer.apply_gradients(zip(grads, [W, b]))\n",
    "        print(\"=> loss %.2f acc %.2f\" %(loss.numpy(), acc.numpy()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
